from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tool import common
from train import trainModel
import threading
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.compat.v1 import keras
from tensorflow.compat.v1.keras import layers
from tensorflow.compat.v1.keras import optimizers
from tensorflow.compat.v1.keras.callbacks import LearningRateScheduler
from tensorflow.compat.v1.keras.callbacks import TensorBoard


def tic_train_type_thread():
    """Train a VGG-16-style classifier on the tic-type dataset in a background thread.

    Builds a 224x224x1-input, 17-class convolutional network, trains it with
    SGD + a step-decay learning-rate schedule, logs to TensorBoard, checkpoints
    weights periodically, then evaluates/predicts on the test set and saves the
    final model to ``./model`` (SavedModel) and ``./model/abc.h5`` (HDF5).

    Returns:
        threading.Thread: the started training thread, so callers may ``join()``
        it. (Backward-compatible: previous callers that ignored the return
        value are unaffected.)
    """

    def tf_train():
        # VGG-16 topology: five conv stages (64/128/256/512/512 filters),
        # each followed by 2x2 max-pooling, then two 4096-unit FC layers.
        # NOTE(review): input is single-channel (grayscale), not RGB —
        # presumably the dataset pipeline emits [224, 224, 1] tensors; verify
        # against common.dataProcess_tic_type.
        model = tf.compat.v1.keras.Sequential(
            [
                layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu",
                              input_shape=[224, 224, 1]),
                layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME", name="m1"),
                layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME", name="m2"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME", name="m3"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME", name="m4"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME", name="m5"),
                layers.Flatten(),
                layers.Dense(4096, activation="relu"),
                layers.Dropout(0.5),
                layers.Dense(4096, activation="relu"),
                layers.Dropout(0.5),
                layers.Dense(17, activation='softmax', name='output')
            ])

        model.summary()

        learning_rate = 1e-3  # fixed: stray trailing semicolon removed
        epoch_num = 2

        def scheduler(epoch):
            """Step-decay schedule: full LR for the first 50% of epochs,
            half LR until 80%, then a tenth for the remainder."""
            if epoch < epoch_num * 0.5:
                return learning_rate
            if epoch < epoch_num * 0.8:
                return learning_rate * 0.5
            return learning_rate * 0.1

        sgd = optimizers.SGD(lr=learning_rate, momentum=0.9, nesterov=True)
        change_lr = LearningRateScheduler(scheduler)
        # sparse_categorical_crossentropy: labels are expected as integer
        # class indices (0..16), not one-hot vectors.
        model.compile(sgd, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        # os.path.join with separate components instead of pre-concatenated
        # "./x" (same resulting path, correct API use).
        logdir = os.path.join(".", common.model_dir)
        tensorboard = TensorBoard(log_dir=logdir, write_graph=True)

        # verbose must be 0 or 1 for ModelCheckpoint (was 5, an invalid value).
        # NOTE(review): logdir is used as the checkpoint filepath prefix —
        # every save with a static path overwrites the previous checkpoint;
        # confirm this is intended.
        cp_callback = tf.compat.v1.keras.callbacks.ModelCheckpoint(
            logdir, save_weights_only=True, verbose=1, period=3)

        # BUG FIX: the tensorboard callback was constructed but never passed
        # to fit(), so no TensorBoard logs were ever written.
        model.fit(common.dataProcess_tic_type.get_dataset_train(),
                  epochs=epoch_num, steps_per_epoch=107,
                  callbacks=[change_lr, cp_callback, tensorboard])

        model.evaluate(common.dataProcess_tic_type.get_dataset_test(), steps=33)

        # NOTE(review): predictions are computed but discarded — presumably a
        # smoke test of the inference path; confirm whether the output should
        # be persisted.
        model.predict(common.dataProcess_tic_type.get_dataset_test(), steps=33)

        model.save('./model')  # TensorFlow SavedModel directory
        print("save1")

        model.save('./model/abc.h5')  # single-file HDF5 copy
        print("save2")

    th = threading.Thread(target=tf_train)
    th.start()
    # Return the thread so callers can join() and wait for training to finish.
    return th